/*
 * Copyright (C) 2016 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio

import java.io.IOException
import java.io.InterruptedIOException
import java.util.Random
import java.util.concurrent.TimeUnit
import okio.ByteString.Companion.decodeHex
import okio.HashingSink.Companion.sha1
import okio.TestUtil.assumeNotWindows
import okio.TestingExecutors.newScheduledExecutorService
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Test

/**
 * Tests for [Pipe]: data flowing from its sink to its source, blocking when the pipe's buffer is
 * full or empty, timeouts on either end, and close semantics.
 */
class PipeTest {
  private val executorService = newScheduledExecutorService(2)

  @After
  fun tearDown() {
    executorService.shutdown()
  }

  /** Bytes written to the sink are readable from the source; closing the sink then yields EOF. */
  @Test
  fun test() {
    val pipe = Pipe(6)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3L)
    val source = pipe.source
    val readBuffer = Buffer()
    assertEquals(3L, source.read(readBuffer, 6L))
    assertEquals("abc", readBuffer.readUtf8())
    pipe.sink.close()
    assertEquals(-1L, source.read(readBuffer, 6L))
    source.close()
  }

  /**
   * A producer writes the first 16 MiB of bytes generated by `java.util.Random(0)` to a sink, and a
   * consumer consumes them. Both compute SHA-1 hashes of their data to confirm that they're as
   * expected.
   */
  @Test
  fun largeDataset() {
    val pipe = Pipe(1000L) // An awkward size to force producer/consumer exchange.
    val totalBytes = 16L * 1024L * 1024L
    val expectedHash = "7c3b224bea749086babe079360cf29f98d88262d".decodeHex()

    // Write data to the sink.
    val sinkHash = executorService.submit<ByteString> {
      val hashingSink = sha1(pipe.sink)
      val random = Random(0)
      val data = ByteArray(8192)
      val buffer = Buffer()
      var i = 0L
      while (i < totalBytes) {
        random.nextBytes(data)
        buffer.write(data)
        hashingSink.write(buffer, buffer.size)
        i += data.size.toLong()
      }
      hashingSink.close()
      hashingSink.hash
    }

    // Read data from the source.
    val sourceHash = executorService.submit<ByteString> {
      val blackhole = Buffer()
      val hashingSink = sha1(blackhole)
      val buffer = Buffer()
      while (pipe.source.read(buffer, Long.MAX_VALUE) != -1L) {
        hashingSink.write(buffer, buffer.size)
        // Only the running hash matters; discard the bytes to keep memory bounded.
        blackhole.clear()
      }
      pipe.source.close()
      hashingSink.hash
    }
    assertEquals(expectedHash, sinkHash.get())
    assertEquals(expectedHash, sourceHash.get())
  }

  /** A write into a full pipe fails with "timeout" once the sink's timeout elapses. */
  @Test
  fun sinkTimeout() {
    assumeNotWindows()
    val pipe = Pipe(3)
    pipe.sink.timeout().timeout(1000, TimeUnit.MILLISECONDS)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3L)
    val start = now()
    try {
      pipe.sink.write(Buffer().writeUtf8("def"), 3L)
      fail()
    } catch (expected: InterruptedIOException) {
      assertEquals("timeout", expected.message)
    }
    assertElapsed(1000.0, start)
    val readBuffer = Buffer()
    assertEquals(3L, pipe.source.read(readBuffer, 6L))
    assertEquals("abc", readBuffer.readUtf8())
  }

  /** A read from an empty pipe fails with "timeout" once the source's timeout elapses. */
  @Test
  fun sourceTimeout() {
    assumeNotWindows()
    val pipe = Pipe(3L)
    pipe.source.timeout().timeout(1000, TimeUnit.MILLISECONDS)
    val start = now()
    val readBuffer = Buffer()
    try {
      pipe.source.read(readBuffer, 6L)
      fail()
    } catch (expected: InterruptedIOException) {
      assertEquals("timeout", expected.message)
    }
    assertElapsed(1000.0, start)
    assertEquals(0, readBuffer.size)
  }

  /**
   * The writer is writing 12 bytes as fast as it can to a 3 byte buffer. The reader alternates
   * sleeping 1000 ms, then reading 3 bytes. That should make for an approximate timeline like
   * this:
   *
   * ```
   *    0: writer writes 'abc', blocks
   *    0: reader sleeps until 1000
   * 1000: reader reads 'abc', sleeps until 2000
   * 1000: writer writes 'def', blocks
   * 2000: reader reads 'def', sleeps until 3000
   * 2000: writer writes 'ghi', blocks
   * 3000: reader reads 'ghi', sleeps until 4000
   * 3000: writer writes 'jkl', returns
   * 4000: reader reads 'jkl', returns
   * ```
   *
   * Because the writer is writing to a buffer, it finishes before the reader does. The reader's
   * result is awaited afterwards so that assertion failures on the reader thread fail this test.
   */
  @Test
  fun sinkBlocksOnSlowReader() {
    val pipe = Pipe(3L)
    // submit() (not execute()) so an AssertionError on the reader thread is captured in the
    // Future and rethrown by get() below, instead of being silently lost on the worker thread.
    val reader = executorService.submit<Unit> {
      val buffer = Buffer()
      Thread.sleep(1000L)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("abc", buffer.readUtf8())
      Thread.sleep(1000L)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("def", buffer.readUtf8())
      Thread.sleep(1000L)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("ghi", buffer.readUtf8())
      Thread.sleep(1000L)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("jkl", buffer.readUtf8())
    }
    val start = now()
    pipe.sink.write(Buffer().writeUtf8("abcdefghijkl"), 12)
    assertElapsed(3000.0, start)
    reader.get()
  }

  /** A writer blocked on a full pipe fails with "source is closed" when the source closes. */
  @Test
  fun sinkWriteFailsByClosedReader() {
    val pipe = Pipe(3L)
    executorService.schedule(
      {
        pipe.source.close()
      },
      1000,
      TimeUnit.MILLISECONDS,
    )
    val start = now()
    try {
      pipe.sink.write(Buffer().writeUtf8("abcdef"), 6)
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
      assertElapsed(1000.0, start)
    }
  }

  /** Flushing the sink returns without waiting for the source to consume the buffered bytes. */
  @Test
  fun sinkFlushDoesntWaitForReader() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.sink.flush()
    val bufferedSource = pipe.source.buffer()
    assertEquals("abc", bufferedSource.readUtf8(3))
  }

  /** Flushing fails if the source was closed while unread bytes remain in the pipe. */
  @Test
  fun sinkFlushFailsIfReaderIsClosedBeforeAllDataIsRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.source.close()
    try {
      pipe.sink.flush()
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
    }
  }

  /** Closing the sink fails if the source was closed while unread bytes remain in the pipe. */
  @Test
  fun sinkCloseFailsIfReaderIsClosedBeforeAllDataIsRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.source.close()
    try {
      pipe.sink.close()
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
    }
  }

  /** Writing to or flushing a closed sink throws IllegalStateException("closed"). */
  @Test
  fun sinkClose() {
    val pipe = Pipe(100L)
    pipe.sink.close()
    try {
      pipe.sink.write(Buffer().writeUtf8("abc"), 3)
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
    try {
      pipe.sink.flush()
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
  }

  /** Closing the sink twice does not throw. */
  @Test
  fun sinkMultipleClose() {
    val pipe = Pipe(100L)
    pipe.sink.close()
    pipe.sink.close()
  }

  /** Closing the sink keeps already-written bytes readable, followed by EOF. */
  @Test
  fun sinkCloseDoesntWaitForSourceRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.sink.close()
    val bufferedSource = pipe.source.buffer()
    assertEquals("abc", bufferedSource.readUtf8())
    assertTrue(bufferedSource.exhausted())
  }

  /** Reading from a closed source throws IllegalStateException("closed"). */
  @Test
  fun sourceClose() {
    val pipe = Pipe(100L)
    pipe.source.close()
    try {
      pipe.source.read(Buffer(), 3)
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
  }

  /** Closing the source twice does not throw. */
  @Test
  fun sourceMultipleClose() {
    val pipe = Pipe(100L)
    pipe.source.close()
    pipe.source.close()
  }

  /** A reader blocked on an empty pipe returns -1 once the sink is closed. */
  @Test
  fun sourceReadUnblockedByClosedSink() {
    val pipe = Pipe(3L)
    executorService.schedule(
      {
        pipe.sink.close()
      },
      1000,
      TimeUnit.MILLISECONDS,
    )
    val start = now()
    val readBuffer = Buffer()
    assertEquals(-1, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals(0, readBuffer.size)
    assertElapsed(1000.0, start)
  }

  /**
   * The writer has 12 bytes to write. It alternates sleeping 1000 ms, then writing 3 bytes. The
   * reader is reading as fast as it can. That should make for an approximate timeline like this:
   *
   * ```
   *    0: writer sleeps until 1000
   *    0: reader blocks
   * 1000: writer writes 'abc', sleeps until 2000
   * 1000: reader reads 'abc'
   * 2000: writer writes 'def', sleeps until 3000
   * 2000: reader reads 'def'
   * 3000: writer writes 'ghi', sleeps until 4000
   * 3000: reader reads 'ghi'
   * 4000: writer writes 'jkl', returns
   * 4000: reader reads 'jkl', returns
   * ```
   *
   * The writer's result is awaited at the end so that a failure on the writer thread fails
   * this test.
   */
  @Test
  fun sourceBlocksOnSlowWriter() {
    val pipe = Pipe(100L)
    // submit() (not execute()) so an exception on the writer thread is surfaced by get() below.
    val writer = executorService.submit<Unit> {
      Thread.sleep(1000L)
      pipe.sink.write(Buffer().writeUtf8("abc"), 3)
      Thread.sleep(1000L)
      pipe.sink.write(Buffer().writeUtf8("def"), 3)
      Thread.sleep(1000L)
      pipe.sink.write(Buffer().writeUtf8("ghi"), 3)
      Thread.sleep(1000L)
      pipe.sink.write(Buffer().writeUtf8("jkl"), 3)
    }
    val start = now()
    val readBuffer = Buffer()
    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("abc", readBuffer.readUtf8())
    assertElapsed(1000.0, start)
    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("def", readBuffer.readUtf8())
    assertElapsed(2000.0, start)
    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("ghi", readBuffer.readUtf8())
    assertElapsed(3000.0, start)
    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("jkl", readBuffer.readUtf8())
    assertElapsed(4000.0, start)
    writer.get()
  }

  /** Returns [System.nanoTime] converted to milliseconds as a double, for measuring timeouts. */
  private fun now(): Double {
    return System.nanoTime() / 1000000.0
  }

  /**
   * Fails the test unless the time from [start] until now is [duration] milliseconds, accepting
   * differences in -50..+450 milliseconds. (The measured time is shifted down by 200 ms and then
   * compared with a tolerance of 250 ms, which skews the accepted window toward running late.)
   */
  private fun assertElapsed(duration: Double, start: Double) {
    assertEquals(duration, now() - start - 200.0, 250.0)
  }
}
